# -*- coding: utf-8 -*-
import json
from datetime import datetime

from common.cache import redis_cache
from common.order.mg_stat import mg_channel_statistic
from common.order.model import *
from common.utils import exceptions as err
from common.utils import id_generator
from common.utils import track_logging
from common.utils.decorator import sql_wrapper
from common.utils.exceptions import ParamError

_LOGGER = track_logging.getLogger(__name__)


def _get_userid_from_extra(extra):
    j = json.loads(extra)
    t = j.get('user_info', {})
    return t.get('user_id', 0)


@sql_wrapper
def create_order(data, chn_id, chn_type):
    """Build and persist a new ``PayOrder`` in the READY state.

    :param data: request payload. It must contain ``mch_id``, ``out_trade_no``,
        ``service``, ``sdk_version``, ``body``, ``total_fee``, ``notify_url``,
        ``mch_create_ip`` and ``sign``; ``extra`` and ``user_id`` are optional.
        ``total_fee`` also seeds ``application_amount``.
    :param chn_id: id of the channel the order is routed to.
    :param chn_type: type of that channel.
    :return: the saved ``PayOrder``.
    :raises err.DataError: on any failure while building or saving the order;
        the underlying exception is logged with its traceback.
    """
    _LOGGER.info("create order data is: %s, chn_id is: %s, chn_type is: %s", data, chn_id, chn_type)
    try:
        order = PayOrder()
        order.id = id_generator.generate_long_id('pay')
        # (model attribute, payload key) pairs copied verbatim from the request,
        # read in the same order so the first missing key is the one reported.
        required_fields = (
            ('mch_id', 'mch_id'),
            ('out_trade_no', 'out_trade_no'),
            ('service', 'service'),
            ('sdk_version', 'sdk_version'),
            ('body', 'body'),
            ('total_fee', 'total_fee'),
            ('application_amount', 'total_fee'),
            ('notify_url', 'notify_url'),
            ('mch_create_ip', 'mch_create_ip'),
            ('sign', 'sign'),
        )
        for column, key in required_fields:
            setattr(order, column, data[key])
        order.status = PAY_STATUS.READY
        order.notify_status = NOTIFY_STATUS.READY
        order.channel_id = chn_id
        order.channel_type = chn_type
        extra = data.get('extra')
        order.extra = extra
        # An explicit user_id wins; otherwise fall back to the one in ``extra``.
        order.user_id = data.get('user_id') or _get_userid_from_extra(extra)
        order.save()
        return order
    except Exception as e:
        _LOGGER.exception('create_order error, %s', e)
        raise err.DataError()


@sql_wrapper
def get_pay(pay_id):
    """Return the ``PayOrder`` whose primary key is ``pay_id``, or None."""
    matches = PayOrder.query.filter_by(id=pay_id)
    return matches.first()


@sql_wrapper
def fill_third_id(pay_id, third_id):
    """Store the third-party transaction id on order ``pay_id`` and commit.

    Issues a bulk UPDATE; an unknown ``pay_id`` simply updates no rows.
    """
    changes = {'third_id': third_id}
    target = PayOrder.query.filter(PayOrder.id == pay_id)
    target.update(changes)
    orm.session.commit()


@sql_wrapper
def fill_extend(pay_id, extend):
    """Serialize ``extend`` to JSON, store it on order ``pay_id`` and commit.

    Non-ASCII text is stored as-is (``ensure_ascii=False``). Issues a bulk
    UPDATE; an unknown ``pay_id`` simply updates no rows.
    """
    target = PayOrder.query.filter(PayOrder.id == pay_id)
    serialized = json.dumps(extend, ensure_ascii=False)
    target.update({'extend': serialized})
    orm.session.commit()


@sql_wrapper
def get_order(mch_id, out_trade_no):
    """Look up an order by merchant id and merchant trade number.

    :return: the first matching ``PayOrder``, or None.
    """
    criteria = (
        PayOrder.mch_id == mch_id,
        PayOrder.out_trade_no == out_trade_no,
    )
    return PayOrder.query.filter(*criteria).first()


@sql_wrapper
def add_pay_success(mch_id, pay_id, total_fee, trade_no, extend):
    """Transition a READY order to SUCC after a successful payment callback.

    The row is selected with ``with_lockmode('update')`` and filtered on
    ``status == READY``. If the order is missing or no longer READY, the
    function returns without changing anything.

    :param mch_id: merchant id; not used by the body (kept for the
        caller-facing signature).
    :param pay_id: id of the order being settled.
    :param total_fee: amount actually paid; overwrites ``total_fee``.
    :param trade_no: upstream transaction number, stored as ``third_id``.
    :param extend: JSON-serializable payload stored (as JSON text) in ``extend``.
    :return: None.
    """
    datetime_now = datetime.utcnow()
    pay = PayOrder.query.filter(PayOrder.id == pay_id).filter(
        PayOrder.status == PAY_STATUS.READY).with_lockmode('update').first()

    if pay is None:
        return
    pay.status = PAY_STATUS.SUCC
    pay.total_fee = total_fee
    pay.third_id = trade_no
    pay.extend = json.dumps(extend, ensure_ascii=False)
    pay.payed_at = datetime_now
    pay.updated_at = datetime_now
    pay.save()
    # Payment succeeded: add it to the channel statistics. These updates are
    # best-effort and must not fail the settlement itself.
    try:
        redis_cache.add_chn_amount(pay.channel_id, pay.total_fee)

        # Successful payment: remove the user's order-creation counter.
        redis_cache.remove_user_pay_count(pay.mch_id, pay.user_id)
        mg_channel_statistic(pay)
    except Exception:
        # Log with traceback (previously INFO with only str(e)) so failures of
        # the statistics backends are visible.
        _LOGGER.exception("add_pay_success cal stats error, pay_id[%s]", pay.id)


@sql_wrapper
def set_pay_success(pay_id, reason, amount=None):
    """Manually mark a READY order as paid (a manual "make-up" order).

    :param pay_id: id of the order to complete; it must currently be READY.
    :param reason: operator-supplied reason, stored as ``extend['reason']``.
    :param amount: optional settled amount. A falsy value keeps the order's
        original ``total_fee``. It must be neither negative nor larger than
        the original ``total_fee``.
    :raises ParamError: if ``amount`` is negative, if no READY order with
        ``pay_id`` exists, or if ``amount`` exceeds the original amount.
    """
    amount = amount if amount else 0
    # Reject negative amounts up front: they would otherwise pass the
    # upper-bound check below and be written straight into total_fee.
    if float(amount) < 0:
        raise ParamError(u'补单金额不能小于0')

    datetime_now = datetime.utcnow()
    pay = PayOrder.query.filter(PayOrder.id == pay_id).filter(
        PayOrder.status == PAY_STATUS.READY).with_lockmode('update').first()

    if pay is None:
        raise ParamError(u'订单号不存在')

    if float(amount) > float(pay.total_fee):
        raise ParamError(u'补单金额不能大于原始金额')

    pay.status = PAY_STATUS.SUCC
    pay.total_fee = amount if amount else pay.total_fee
    # Merge the reason into any existing extend payload. Both paths now
    # serialize with ensure_ascii=False so non-ASCII reasons are stored
    # consistently (previously only the "existing payload" path did).
    extend = json.loads(pay.extend) if pay.extend else {}
    extend['reason'] = reason
    pay.extend = json.dumps(extend, ensure_ascii=False)
    pay.payed_at = datetime_now
    pay.updated_at = datetime_now
    pay.save()
    # Payment succeeded: add it to the channel statistics.
    redis_cache.add_chn_amount(pay.channel_id, pay.total_fee)
    mg_channel_statistic(pay)


@sql_wrapper
def record_manual_order(pay_id, operator, reason):
    """Copy order ``pay_id`` into a ``ManualOrder`` audit record.

    Does nothing if a ``ManualOrder`` with this id already exists.

    :param pay_id: id of the ``PayOrder`` to snapshot.
    :param operator: object exposing ``id`` and ``nickname`` (presumably the
        admin who performed the manual action; confirm against callers).
    :param reason: free-text reason stored on the record.
    :raises ParamError: if no ``PayOrder`` with ``pay_id`` exists (previously
        this crashed with AttributeError on ``pay.id``).
    """
    if ManualOrder.query.filter(ManualOrder.id == pay_id).first():
        return
    pay = get_pay(pay_id)
    if pay is None:
        raise ParamError(u'订单号不存在')

    morder = ManualOrder()
    # Fields copied verbatim from the order (each assigned once; the original
    # assigned notify_status twice).
    copied_fields = (
        'id', 'mch_id', 'out_trade_no', 'service', 'status', 'channel_id',
        'total_fee', 'notify_status', 'user_id', 'payed_at', 'notified_at',
        'created_at', 'updated_at',
    )
    for field in copied_fields:
        setattr(morder, field, getattr(pay, field))
    morder.operator_id = operator.id
    morder.operator = operator.nickname
    morder.reason = reason
    morder.save()


@sql_wrapper
def add_pay_fail(mch_id, pay_id, total_fee, trade_no, extend):
    """Transition a READY order to FAIL via a conditional bulk UPDATE.

    Only rows that are still READY are updated, so a concurrent or repeated
    callback that already moved the order out of READY updates nothing.

    :param mch_id: merchant id; not used by the body (kept for the
        caller-facing signature).
    :param pay_id: id of the order.
    :param total_fee: amount reported by the callback.
    :param trade_no: upstream transaction number, stored as ``third_id``.
    :param extend: JSON-serializable payload stored (as JSON text) in ``extend``.
    :return: True if a row was updated and committed, False otherwise.
    """
    datetime_now = datetime.utcnow()
    res = PayOrder.query.filter(PayOrder.id == pay_id).filter(
        PayOrder.status == PAY_STATUS.READY).update({
        'status': PAY_STATUS.FAIL,
        'total_fee': total_fee,
        'third_id': trade_no,
        'extend': json.dumps(extend, ensure_ascii=False),
        'payed_at': datetime_now,
        'updated_at': datetime_now
    })
    if res:
        orm.session.commit()
        return True
    # No READY row matched: the order is missing or was already settled.
    # ``warning`` replaces the deprecated ``warn`` alias (removed in Python 3.13).
    _LOGGER.warning('add_pay_fail, concurrency occurred! pay_id[%s]', pay_id)
    return False


@sql_wrapper
def add_notify_success(pay_id):
    """Record a successful merchant notification for order ``pay_id``.

    Locks the row with ``with_lockmode('update')``, sets ``notify_status`` to
    SUCC, increments ``notify_count`` and stamps ``notified_at`` with the
    current UTC time.

    :return: the updated ``PayOrder``.
    """
    locked = PayOrder.query.filter(PayOrder.id == pay_id).with_lockmode('update')
    order = locked.first()
    order.notify_status = NOTIFY_STATUS.SUCC
    order.notify_count = order.notify_count + 1
    order.notified_at = datetime.utcnow()
    order.save()
    return order


@sql_wrapper
def add_notify_fail(pay_id):
    """Record a failed merchant notification attempt for order ``pay_id``.

    Locks the row with ``with_lockmode('update')``, sets ``notify_status`` to
    FAIL, increments ``notify_count`` and stamps ``notified_at`` with the
    current UTC time.

    :return: the updated ``PayOrder``.
    """
    pay_order = PayOrder.query.filter(PayOrder.id == pay_id).with_lockmode('update').first()
    pay_order.notify_status = NOTIFY_STATUS.FAIL
    pay_order.notify_count += 1
    # Was ``notify_at``: a typo for the ``notified_at`` column (see
    # add_notify_success), so the timestamp was never persisted.
    pay_order.notified_at = datetime.utcnow()
    pay_order.save()
    return pay_order


@sql_wrapper
def get_total_mount(channel_id, start_time, end_time):
    """Sum ``total_fee`` of successful orders on a channel within a time window.

    :param channel_id: channel to aggregate.
    :param start_time: inclusive lower bound on ``payed_at``.
    :param end_time: exclusive upper bound on ``payed_at``.
    :return: the total as a float; 0.0 when no orders match.
    """
    # Fetch only the fee column instead of materializing full ORM objects.
    fees = PayOrder.query.with_entities(PayOrder.total_fee).filter(
        PayOrder.channel_id == channel_id,
        PayOrder.status == PAY_STATUS.SUCC,
        PayOrder.payed_at >= start_time,
        PayOrder.payed_at < end_time).all()
    return sum((float(fee) for (fee,) in fees), 0.0)


@sql_wrapper
def get_orders_by_channel(channel_id, start_time, end_time):
    """Return READY orders on a channel created within ``[start_time, end_time)``.

    :return: list of ``PayOrder`` rows (possibly empty).
    """
    conditions = (
        PayOrder.channel_id == channel_id,
        PayOrder.status == PAY_STATUS.READY,
        PayOrder.created_at >= start_time,
        PayOrder.created_at < end_time,
    )
    return PayOrder.query.filter(*conditions).all()
